#!/usr/bin/env python3

import os
import re
import sys
import socket
import time
import subprocess
import argparse
import logging
from ctypes import CDLL, create_string_buffer

from nicelogger import enable_pretty_logging
# Set up logging through nicelogger before anything below logs
# ('DEBUG' is presumably the minimum level to show -- confirm in nicelogger).
enable_pretty_logging('DEBUG')

# (host, port) of the rendezvous server that main() registers with to learn
# its own external address and the peer's.
udp_server = ('xmpp.vim-cn.com', 2727)

def disconnect(sock):
  '''Dissolve the association of a connected datagram socket.

  Python's socket API has no way to "un-connect" a socket, so connect(2)
  is called through libc with an all-zero address: its sa_family is then
  0 (AF_UNSPEC), which asks the kernel to drop the association.

  Args:
    sock: a connected socket object; only its fileno() is used.

  Raises:
    OSError: if the underlying connect(2) call reports failure.
  '''
  from ctypes import get_errno

  # use_errno=True keeps a private copy of errno so it can be reported below
  libc = CDLL("libc.so.6", use_errno=True)
  buf = create_string_buffer(16) # sizeof struct sockaddr_in, zero-filled
  if libc.connect(sock.fileno(), buf, 16) != 0:
    # a socket that silently stays connected would still only talk to the
    # old destination, so fail loudly instead
    err = get_errno()
    raise OSError(err, os.strerror(err))

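# Address variable naming: the first letter says whose address it is
# (p = peer, s = self) and the second says which one (e = external, as
# reported by the rendezvous server; l = local, as seen on the host itself).
# So pe is 'peer external', sl is 'self local', etc.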

def main(server, prog, name):
  '''Punch a UDP hole to the peer registered under the same name and run prog.

  Steps:
    1. Register `name` with the rendezvous server (`udp_server`) and wait
       until a reply carries both our external address and the peer's.
    2. Exchange local addresses directly with the peer (best effort).
    3. Hand the socket and all known addresses to `sprog_<prog>` when
       `server` is true, otherwise to `cprog_<prog>`.

  Args:
    server: true to act as the server side.
    prog: program name; selects the sprog_/cprog_ handler from globals().
    name: str identifying this pair of peers to the rendezvous server.

  Exits with status 2 if the peer's external address is never received.
  '''
  sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  sock.settimeout(2)

  # connect for self local ip
  sock.connect(udp_server)
  logging.info('registering my name %s and finding peer...', name)
  reg_msg = b'reg ' + name.encode('utf-8')
  pe_ip = None
  sock.send(reg_msg)
  for _ in range(1, 30):
    try:
      msg, _addr = sock.recvfrom(1024)
    except socket.timeout:
      # no answer yet: register again
      sock.send(reg_msg)
      continue

    # the first word of the reply is skipped; the rest are ip/port pairs
    # (slicing instead of unpacking, so an empty reply cannot raise)
    items = msg.decode('utf-8', errors='replace').split()[1:]
    logging.info('got addresses: %s', items)
    try:
      if len(items) == 2:
        # only our own external address so far; keep waiting for the peer's
        se_ip, se_port = items[0], int(items[1])
      elif len(items) == 4:
        # convert both ports first so a bad one leaves nothing half-assigned
        my_port, peer_port = int(items[1]), int(items[3])
        se_ip, se_port = items[0], my_port
        pe_ip, pe_port = items[2], peer_port
        break
    except ValueError:
      logging.warning('ignoring malformed reply from server: %r', msg)
  if pe_ip is None:
    logging.fatal("failed to get peer's address.")
    sock.close()
    sys.exit(2)

  sl_ip, sl_port = sock.getsockname()
  logging.info('Got my IP and Port: (%r, %s), (%r, %s).', se_ip, se_port, sl_ip, sl_port)

  # be optimistic, works even when peer is symmetric
  dst = pe_ip, pe_port
  # NOTE(review): sl_port was read before this call and is used afterwards;
  # confirm the local port survives the disconnect on the target kernel.
  disconnect(sock)
  logging.info('send and receive local ip...')
  pl_ip = pl_port = None
  my_local = ('%s %d' % (sl_ip, sl_port)).encode('utf-8')
  for _ in range(1, 5):
    # TODO: require ACK before stop
    # Until then every round is played even after a successful receive, so
    # the peer gets several chances to see our address.
    try:
      sock.sendto(my_local, dst)
      raw, addr = sock.recvfrom(1024)
      logging.debug('recv from %r: %s', addr, raw)
      ip, port = raw.decode('utf-8', errors='replace').split()
      # int() runs before either name is bound, so a bad port changes nothing
      pl_ip, pl_port = ip, int(port)
    except (socket.timeout, ValueError):
      pass
    except ConnectionRefusedError:
      # peer not ready
      time.sleep(2)
  if pl_ip is None:
    logging.warning("failed to get peer's address, continue anyway.")

  handler = globals()[('sprog_' if server else 'cprog_') + prog]
  handler(sock, sl_port, pl_ip, pl_port, se_ip, se_port, pe_ip, pe_port)

def sprog_mosh(sock, port, pl_ip, pl_port, se_ip, se_port, pe_ip, pe_port):
  '''Server-side handler for mosh.

  Closes the hole-punching socket, starts `mosh-server new -p <port>` and
  prints the command line the client side should run. Of the address
  arguments only pl_port, se_ip and se_port are used.
  '''
  # the socket is closed before mosh-server is asked to use the same port
  sock.close()
  logging.info('Starting mosh server...')
  command = ['mosh-server', 'new', '-p', str(port)]
  output = subprocess.check_output(command)
  # the fourth whitespace-separated field of the output is used as the key
  fields = output.split()
  secret = fields[3].decode()
  hint = 'Connect with:\nMOSH_KEY=%s MOSH_CPORT=%s mosh-client-lily %s %s' % (secret, pl_port, se_ip, se_port)
  print(hint)

def cprog_mosh(*args):
  '''Client-side handler for mosh: ignores its arguments and only logs.'''
  # nothing is launched here; all arguments are accepted and discarded
  logging.info('done.')

def sprog_openvpn(sock, port, pl_ip, pl_port, se_ip, se_port, pe_ip, pe_port):
  '''Server-side handler for OpenVPN.

  Writes the server config to `hole_s.ovpn` in the current directory. If the
  peer's local address is known (pl_ip is not None), the matching client
  config is sent to the peer through the punched hole, retrying up to five
  times until an b'ok' acknowledgement arrives. The socket is then closed
  and OpenVPN is run under sudo with the generated server config; the call
  returns when OpenVPN exits.

  Args:
    sock: the UDP socket used for hole punching; closed by this function.
    port: our local port, used as the OpenVPN server port.
    pl_ip, pl_port: peer's local address (pl_ip may be None if unknown).
    se_ip, se_port: our external address, the client's `remote`.
    pe_ip, pe_port: peer's external address, where the config is sent.
  '''
  stmpl = '''\
dev tun
proto udp
port {port}
mssfix 1400
keepalive 10 60

ca /etc/openvpn/easy-rsa/keys/ca.crt
cert /etc/openvpn/easy-rsa/keys/server.crt
key /etc/openvpn/easy-rsa/keys/server.key
dh /etc/openvpn/easy-rsa/keys/dh1024.pem

user nobody
group nobody
server 10.7.0.0 255.255.255.0

persist-key
persist-tun

client-to-client
comp-lzo
'''
  ctmpl = '''\
client
max-routes 2048
dev tun
remote {se_ip} {se_port} udp
resolv-retry infinite
local {pl_ip}
port {pl_port}
mssfix 1400
keepalive 10 60
persist-key
persist-tun
ns-cert-type server
comp-lzo
verb 3
route-nopull
route 10.7.0.0 255.255.255.0
'''
  with open('hole_s.ovpn', 'w') as f:
    f.write(stmpl.format(port=port))
  # absolute path is needed because openvpn is started with --cd /etc/openvpn
  sconfig = os.path.join(os.getcwd(), 'hole_s.ovpn')

  if pl_ip is not None:
    s = ctmpl.format(
      se_ip=se_ip, se_port=se_port,
      pl_ip=pl_ip, pl_port=pl_port
    ).encode()
    logging.info('sending client-side config to %s:%d...', pe_ip, pe_port)

    sock.settimeout(1)
    sock.connect((pe_ip, pe_port))
    for _ in range(5):
      try:
        sock.sendto(s, (pe_ip, pe_port))
        msg, _addr = sock.recvfrom(1024)
      except socket.timeout:
        continue
      except ConnectionRefusedError:
        # peer not ready (same meaning as in main()); this error returns
        # immediately, so pause to keep the retries about a second apart
        time.sleep(1)
        continue
      if msg == b'ok':
        logging.info('client-side config sent. starting openvpn')
        break
    else:
      # all five attempts went unacknowledged
      logging.warning('ACK not received, starting anyway...')
  # close the socket before openvpn is started on the same port
  sock.close()

  p = subprocess.Popen(['sudo', '/usr/bin/openvpn', '--cd', '/etc/openvpn', '--config', sconfig])
  try:
    p.wait()
  except KeyboardInterrupt:
    # on Ctrl-C keep waiting for the child to finish instead of abandoning it
    p.wait()

def cprog_openvpn(sock, port, pl_ip, pl_port, se_ip, se_port, pe_ip, pe_port):
  '''Client-side handler for OpenVPN.

  When the peer's local address is known (pl_ip is truthy), waits for the
  peer to send a client config (a datagram starting with b'client'),
  acknowledges it with b'ok' and uses it. Otherwise a generic client config
  pointing at the peer's external address is generated locally. The config
  plus the contents of the auth file are written to `hole.ovpn`, and the
  process is replaced by `sudo openvpn hole.ovpn`.

  Depends on the module-level `args` (set in the __main__ section) for
  `args.auth`, an open file whose contents are appended to the config.

  Exits with status 2 if no config arrives within the 20 second timeout.
  '''
  if pl_ip:
    logging.info('receiving config')
    sock.settimeout(20)
    sock.connect((pe_ip, pe_port))
    try:
      while True:
        try:
          msg = sock.recv(1024)
        except ConnectionRefusedError:
          # treated as "peer not ready", as in main(); wait for the next datagram
          continue
        # anything that does not look like a client config is ignored
        if msg.strip().startswith(b'client'):
          break
      sock.send(b'ok')
    except socket.timeout:
      logging.fatal('timed out waiting for config from peer.')
      sys.exit(2)
    finally:
      # close the socket before openvpn is started with the received config
      sock.close()
    # NOTE(security): this text comes from the peer and is handed verbatim to
    # openvpn running under sudo; it is only checked to start with 'client'.
    msg = msg.decode()
    logging.info('config received as hole.ovpn')
  else:
    # try connecting with any address, hoping peer side will pass it through
    msg = '''\
client
max-routes 2048
dev tun
remote {pe_ip} {pe_port} udp
resolv-retry infinite
mssfix 1400
keepalive 10 60
persist-key
persist-tun
ns-cert-type server
comp-lzo
verb 3
route-nopull
route 10.7.0.0 255.255.255.0
'''.format(pe_ip = pe_ip, pe_port = pe_port)
  with open('hole.ovpn', 'w') as f:
    f.write(msg)
    # `args` is the module-level argparse result; args.auth is an open file
    auth = args.auth.read()
    f.write(auth)
  # replaces the current process; nothing after this line runs
  os.execvp('sudo', ['sudo', 'openvpn', 'hole.ovpn'])

if __name__ == '__main__':
  # Command-line entry point: parse options, validate them, then run main().
  parser = argparse.ArgumentParser(description='perform UDP hole punching and run some program')
  parser.add_argument('-s', '--server', action='store_true', default=False,
                      help='server-side')
  parser.add_argument('-n', '--name', required=True,
                      help="set a name to identify peers. DON'T use same one in a few minutes!")
  parser.add_argument('prog', metavar='PROG', default='mosh', nargs='?',
                      help='program to run via the hole')
  # type=open makes argparse open the file; cprog_openvpn later reads it
  parser.add_argument('-a', '--auth', type=open,
                      help='OpenVPN auth info file (client)')
  # `args` must stay a module-level name: cprog_openvpn reads args.auth
  args = parser.parse_args()

  # check for the handler main() will actually dispatch to in this mode
  prefix = 'sprog_' if args.server else 'cprog_'
  if not callable(globals().get(prefix + args.prog)):
    sys.exit("don't know how to cope with program %r" % args.prog)

  if args.prog == 'openvpn' and not args.server and not args.auth:
    sys.exit('Auth file should be provided for OpenVPN client')

  main(args.server, args.prog, args.name)
